Conversation
|
What an unexpected and amazing surprise! I'm absolutely thrilled. |
|
@awni |
|
I think this is good to stay as an experiment branch for some time while we work on core and CUDA. I don't think we have the bandwidth to merge this for a few months at least. Sorry if this is disappointing, @NripeshN; I don't mean to discourage you from working on it. |
|
I would love to see the ROCm backend get more traction. AMD's new AI series of processors has a similar advantage to Apple Silicon with unified memory, and getting MLX to run on those processors would be neat. |
|
Stole my idea :( |
|
How is this even possible for such an awesome PR to be left like this? |
Pull request overview
This PR adds experimental ROCm backend support to MLX, enabling execution on AMD GPUs. The implementation mirrors the CUDA backend structure, providing HIP-based implementations of core operations, memory management, and device handling.
Changes:
- Added ROCm backend infrastructure with device management, memory allocation, and stream handling
- Implemented HIP kernels for unary, binary, ternary operations, reductions, normalization (softmax, layer_norm, rms_norm), RoPE, and sorting
- Updated build system (CMake) to support ROCm compilation with configurable GPU architectures
Reviewed changes
Copilot reviewed 59 out of 59 changed files in this pull request and generated 13 comments.
| File | Description |
|---|---|
| CMakeLists.txt | Added MLX_BUILD_ROCM option and ROCm library detection |
| mlx/CMakeLists.txt | Integrated ROCm backend build configuration |
| mlx/device.cpp | Added ROCm device availability checks |
| mlx/backend/rocm/*.hip | HIP kernel implementations for various operations |
| mlx/backend/rocm/device.* | ROCm device and stream management |
| mlx/backend/rocm/allocator.* | ROCm-specific memory allocator using HIP unified memory |
| mlx/backend/rocm/worker.* | Async task execution worker for stream synchronization |
| mlx/backend/rocm/utils.* | HIP utility functions and error handling |
| mlx/backend/rocm/jit_module.* | JIT compilation support using HIPRTC |
| mlx/backend/rocm/device/*.hpp | Device-side utility functions and type definitions |
| mlx/backend/rocm/CMakeLists.txt | ROCm backend build configuration |
…ather, scatter, logsumexp, random bits generation, and sorting. Introduce new kernels for efficient computation and integrate with existing ROCm utilities. Update CMake configuration to include new source files and dependencies. Enhance error handling and ensure compatibility with different data types. This commit significantly expands the functionality of the ROCm backend.
|
👑👑👑 |
|
Can anyone run:

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON" pip install -e .

or, to target a specific GPU architecture:

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES={based on your GPU}" pip install -e .

Replace {based on your GPU} with your GPU architecture. You can run rocm-smi to get your GPU information. |
|
I'm getting this CMake error (running on Strix Halo, gfx1151): |
Could you retry with the latest push, please? (P.S. keep your fingers crossed while it compiles; it worked for me on the 138th try)😅
… string formatting, replacing fmt library usage. Remove unused event.cpp file. Update kernel name generation and parameter formatting for consistency.
Now what can I test? 😍 |
|
I'm getting this: |
I forgot to test the Python build, my bad. Can you try it now? Unfortunately I might not be able to help after it compiles, since I don't have an AMD GPU to run tests😔 I've tried replicating most things from CUDA, so hopefully it works |
Root cause: ensure_row_contiguous_matrix only checked the last 2 dimensions. Arrays from expand_dims (SwitchGLU MoE path) had non-contiguous batch strides that passed the check but caused out-of-bounds access when the kernel used flat pointer arithmetic (x + lhs_idx * M * K).

Fix:
- GatherQMM::eval_gpu: use ensure_row_contiguous (full contiguity check) for all inputs, not just ensure_row_contiguous_matrix (last-2-dims)
- Add LHS_B parameter (valid x batch count) to both gather kernels
- Add bounds clamping: lhs_idx < LHS_B, rhs_idx < E
- QuantizedMatmul (non-gather) unchanged: no batch indirection
RMSNorm (called 72x per forward pass):
- Replace rsqrtf() hardware approximation with 1.0f/sqrtf() for IEEE compliance (Metal uses precise::rsqrt)
- Match Metal's weight application order: truncate to T between normalization and weight multiply (intermediate rounding step)
- Same fix applied to LayerNorm

Sort/ArgSort:
- Add is_sort_floating_v trait that includes __half and hip_bfloat16 (std::is_floating_point_v is false for these, skipping NaN handling)
- Fix NaN comparison and sentinel values for half types
- Add __half nan_value specialization

SDPA:
- Fix max_score initialization: use Limits<U>::finite_min (-FLT_MAX) instead of -1e9f (matches Metal)
- Fix zero-sum normalization edge case

Standalone ops (binary_ops.hpp, unary_ops.hpp):
- Promote __half and hip_bfloat16 through float for Add, Subtract, Multiply, Divide (Metal auto-promotes, ROCm doesn't)
- Add float promotion for unary ops with __half inputs

JIT preamble (compiled.cpp):
- Remove redundant float promotion for Add/Subtract/Multiply/Divide (already promoted in a previous commit; clean up duplicate logic)
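The is_sort_floating_v idea can be sketched host-side; `__half` and `hip_bfloat16` are device types, so the structs below are hypothetical stand-ins. The point is that `std::is_floating_point_v` is false for vendor half types, so NaN handling in the sort is skipped unless the trait is widened.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins for the device-side half types.
struct fake_half {};      // stands in for __half
struct fake_bfloat16 {};  // stands in for hip_bfloat16

// Trait that treats the half types as floating point so the sort's
// NaN-aware comparison and sentinel logic applies to them too.
template <typename T>
inline constexpr bool is_sort_floating_v =
    std::is_floating_point_v<T> ||
    std::is_same_v<T, fake_half> ||
    std::is_same_v<T, fake_bfloat16>;
```

A sort kernel would then branch on `is_sort_floating_v<T>` instead of `std::is_floating_point_v<T>` when deciding whether to special-case NaN.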
The non-uniform-stride batch loop in gemm_and_bias() called rocBLAS directly (bypassing the naive_gemm wrapper that was patched earlier) and only handled float32/float64; bfloat16 and float16 matmuls silently did nothing, leaving the output buffer uninitialized. This caused non-deterministic SDPA results for any GQA model (where n_q_heads != n_kv_heads) at sequence lengths >= 4, with progressively worse corruption (NaN/Inf at L >= 7). The SDPA fallback decomposition reshapes Q via unflatten and K/V via expand_dims for GQA broadcasting, which produces non-uniform batch strides that hit this code path.

Fix: always use naive_gemm_with_offset for the non-uniform-stride batch loop, matching the approach already used by the single-GEMM and strided-batched paths.
The supports_sdpa_vector() function listed head_dim=256 as supported, but the sdpa_vector() dispatch only had cases for D=64, 96, 128. For D=256 no kernel was launched, leaving the output buffer uninitialized and causing non-deterministic results for models using head_dim=256 (e.g. Qwen3-Next) at sequence lengths 1-3.
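The bug class is a capability check disagreeing with its dispatch switch. A minimal sketch (function names mirror the ones mentioned above, but the bodies are illustrative): the two lists must stay in sync, and the dispatcher should report whether a kernel was actually launched instead of silently falling through.

```cpp
#include <cassert>

// Capability check consulted before choosing the vector SDPA path.
bool supports_sdpa_vector(int head_dim) {
  return head_dim == 64 || head_dim == 96 || head_dim == 128 ||
         head_dim == 256;
}

// Dispatch: returns true only if a kernel case exists for this head_dim.
// After the fix, every dimension the capability check accepts has a case.
bool dispatch_sdpa_vector(int head_dim) {
  switch (head_dim) {
    case 64:
    case 96:
    case 128:
    case 256:
      return true;  // a kernel instantiation would be launched here
    default:
      return false; // no kernel: caller must fall back, not continue
  }
}
```

Checking `supports_sdpa_vector(d) == dispatch_sdpa_vector(d)` for every candidate dimension (e.g. in a unit test) would have caught the uninitialized-output case before it surfaced as model corruption.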
|
I have a lot of changes to merge in. I am testing my port of mlx-swift-lm (https://github.com/lemonade-sdk/lemon-mlx-engine) against the MLX ROCm core (https://github.com/lemonade-sdk/lemon-mlx-core-amd). I got Qwen3 working and am working on Qwen3Next right now; it's having weird issues. There are tons of problems with the ROCm backend that I have traced to "different rounding" causing unstable outputs, but a lot of that is fixed now, at least for Qwen models. Once I get your changes merged into my repo I will push a PR into yours with my changes. I have made optimizations as well; there are problems with the fallback system when functions in rocBLAS aren't compatible or don't exist for the architecture. |
|
Once I get Qwen3Next working at a reasonable speed I will do the PR. |
Merges goniz/rocm-support-fixes: flash attention kernel, allocator redesign for integrated GPUs, bfloat16 math overloads, QMV vectorization, depthwise conv1d, event sync improvements, rocBLAS solution-index dispatch, and upstream main (CUDA, docs, quantization). Conflicts resolved preferring upstream for most ROCm backend files, keeping our SliceUpdate kernel and float-promotion JIT approach.
The gather_qmv_warp_shared_kernel (wave-cooperative, shared memory tiling, vectorized 4-bit unpacking) was only dispatched for 6-bit and 8-bit quantization. 4-bit fell through to the naive gather_qmv_kernel (1 thread per output, sequential K loop), which was 18.6x slower. Add bits==4 to the fast dispatch condition; the kernel already handles 4-bit internally with 8-element vectorized unpacking.

Profiled impact (Qwen3-Next 4-bit MoE):
- gather_qmv_kernel: 5193 μs/call → (removed)
- gather_qmv_warp_shared_kernel: N/A → 279 μs/call (18.6x)
Key changes for Strix Halo / RDNA 3.5 integrated GPU:
1. raw_ptr(): Use hipStreamSynchronize(nullptr) instead of hipDeviceSynchronize() for unified memory buffers. Only waits on the default stream instead of all streams. Skips the expensive move_to_unified_memory() since integrated GPU memory is already CPU-accessible (device==-1).
2. malloc(): Integrated GPU path now goes through rocm_unified_malloc(), which sets device=-1, so raw_ptr() takes the fast path.
3. rocm_unified_malloc(): Integrated GPUs try hipExtMallocWithFlags (fine-grained coherent) first, falling back to hipMallocManaged.

Profiled impact on Qwen3-Next 4-bit MoE:
- Generation: 12.0 tok/s → 18.9 tok/s (58% faster)
- Prompt: 2.5 tok/s → 5.2 tok/s (2x faster)
The noshared QMV kernel reads x from global memory redundantly per warp (each warp reloads the same x vector). The shared variant caches x in LDS and is significantly faster for decode-sized (M<=8) shapes. Disable the alignment-based noshared path selection; always use the shared variant unless K is tiny. This reduces redundant global memory traffic for dense quantized projections.
For MoE prefill (M>1) with sorted rhs_indices, consecutive batch elements map to the same expert. The existing gather_qmv_warp_shared kernel launches B independent blocks that each load the same expert weights from global memory: 60-75x redundant weight traffic.

New gather_qmv_prefill_kernel groups batch elements into contiguous runs of same-expert assignments. Each block handles one (run, row, col) and iterates over all batch elements in the run, reading weights once. Grid z-dimension = num_runs (~8-10 unique experts) instead of B (~600). Supports 4-bit and 8-bit affine quantization with vectorized unpacking (8 elements per iteration for 4-bit, 4 for 8-bit) and fmaf accumulation.

Profiled impact (Qwen3-Next 4-bit MoE, 40-token prompt):
- Prompt: 1.8 tok/s → 6.1 tok/s (3.4x faster)
- gather_qmv total: 502ms → ~150ms
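The run-grouping step can be sketched host-side (the struct and function are illustrative, assuming the expert indices are already sorted): walk the sorted indices and emit one [begin, end) run per distinct expert, so each run reads its expert's weights once.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// One contiguous run of batch elements assigned to the same expert.
struct Run {
  int32_t expert;
  int32_t begin;  // inclusive
  int32_t end;    // exclusive
};

// Group a sorted expert-index array into contiguous same-expert runs.
std::vector<Run> find_expert_runs(const std::vector<int32_t>& sorted_idx) {
  std::vector<Run> runs;
  int32_t n = (int32_t)sorted_idx.size();
  for (int32_t i = 0; i < n;) {
    int32_t j = i;
    while (j < n && sorted_idx[j] == sorted_idx[i]) ++j;
    runs.push_back({sorted_idx[i], i, j});
    i = j;  // next run starts where this one ended
  }
  return runs;
}
```

The kernel grid's z-dimension then ranges over `runs.size()` (the ~8-10 unique experts mentioned above) rather than the full batch size B.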
New gather_qmv_wmma_prefill_kernel uses rocWMMA 16x16x16 bf16→f32 tiles for matrix multiply-accumulate during MoE prefill. Each wave32 handles a 16x16 output tile, dequantizing 4-bit weights into shared memory and using rocwmma::mma_sync for the reduction. Enabled for gfx11 (RDNA 3/3.5) and gfx12 (RDNA 4) when M >= 16 and dimensions are 16-aligned. Falls back to the scalar kernel otherwise. Guarded by the ROCM_HAS_WMMA macro so gfx9/gfx10 builds are unaffected.

Also restores hipExtMallocWithFlags as the primary allocator for APUs (reverts the hipMallocManaged experiment; fine-grained coherent gives better GPU kernel bandwidth).

Profiled impact (Qwen3-Coder-Next 4-bit, Strix Halo gfx1151):
- Prompt (40 tok): 84 tok/s → 117 tok/s (39% faster)
- Qwen3-8B prompt: 33 tok/s → 44 tok/s (33% faster)
- Generation: unchanged at ~18 tok/s
- Remove M%16 alignment requirement: kernel now bounds-checks rows, padding with zero for tile positions beyond M.
- Remove right_sorted_ requirement from prefill dispatch: a CPU-side sort creates sorted index arrays and an output permutation for any index order.
- Add out_perm parameter to both WMMA and scalar prefill kernels to scatter results back to original batch positions after sorted dispatch.
- Add <algorithm> and <numeric> includes for std::sort/std::iota.

NOTE: MLX's MoE layer (SwitchGLU) currently expands all tokens to individual M=1 calls via gather_qmm. The prefill kernels (M>1) will activate when upstream changes to batch tokens per expert. The 4-bit fast gather_qmv_warp_shared dispatch handles the current M=1 path.
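The CPU-side sort plus out_perm idea can be sketched as follows (a hypothetical helper, not MLX's actual code): build a permutation that sorts rhs_indices, so the kernel can walk same-expert runs, and record where each sorted element came from so results can be scattered back to original batch positions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Produce (a) expert indices in sorted order and (b) out_perm, where
// out_perm[i] is the original batch position of the i-th sorted element,
// used to scatter kernel outputs back after sorted dispatch.
void build_sorted_dispatch(const std::vector<int32_t>& rhs_indices,
                           std::vector<int32_t>& sorted_indices,
                           std::vector<int32_t>& out_perm) {
  out_perm.resize(rhs_indices.size());
  std::iota(out_perm.begin(), out_perm.end(), 0);
  // Stable sort keeps the original order within each expert's run.
  std::stable_sort(out_perm.begin(), out_perm.end(), [&](int32_t a, int32_t b) {
    return rhs_indices[a] < rhs_indices[b];
  });
  sorted_indices.resize(rhs_indices.size());
  for (size_t i = 0; i < out_perm.size(); ++i)
    sorted_indices[i] = rhs_indices[out_perm[i]];
}
```

With this, the prefill kernels no longer require the caller to pass already-sorted indices; any index order works at the cost of one small host-side sort.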
New gather_qmv_expert_batched_kernel finds expert run boundaries on-GPU via binary search of sorted rhs_indices. Each block handles one (expert, column) pair and iterates over all tokens for that expert, loading weights once per expert. Dispatch condition: E <= 64 and B/E >= 4 (low expert count with many tokens per expert).

For high-expert models (E=512, like Qwen3-Next), the warp_shared kernel remains faster, since most runs have only 1-4 tokens and the per-block run-finding overhead isn't justified.
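The binary-search run-finding can be sketched with standard algorithms (a host-side illustration; the kernel performs the equivalent search per block on the device): an expert's token range in the sorted index array is exactly the [lower_bound, upper_bound) interval for that expert id.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Return the [begin, end) token range for one expert within sorted indices.
// An absent expert yields an empty range (begin == end).
std::pair<int32_t, int32_t> expert_range(
    const std::vector<int32_t>& sorted_idx, int32_t expert) {
  auto lo = std::lower_bound(sorted_idx.begin(), sorted_idx.end(), expert);
  auto hi = std::upper_bound(sorted_idx.begin(), sorted_idx.end(), expert);
  return {(int32_t)(lo - sorted_idx.begin()),
          (int32_t)(hi - sorted_idx.begin())};
}
```

Each block computes its expert's range once (O(log B)) and then loops over the tokens in that range, which is why the dispatch condition favors few experts with many tokens each.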
hipBLASLt provides architecture-tuned GEMM kernels via Tensile, typically outperforming rocBLAS for bf16/fp16 on RDNA 3.5 and CDNA. New hipblaslt_gemm() and hipblaslt_gemm_batched() functions with:
- Per-device handle cache (thread-safe, lazily initialized)
- Algorithm heuristic selection (best-of-1 from hipBLASLt)
- RAII guards for all descriptor types
- Persistent workspace allocation (up to 32MB, grown as needed)
- fp32 accumulation for bf16/fp16 inputs

matmul.cpp tries hipBLASLt first for bf16/fp16 and falls back to rocBLAS silently on failure. Float32/64 GEMMs are unchanged.
The dequant+GEMM path in QuantizedMatmul now tries hipBLASLt before rocBLAS for bf16 GEMMs. hipBLASLt selects architecture-tuned kernels via heuristic algorithm search, significantly outperforming rocBLAS once the algorithm cache is warm. New hipblaslt_gemm_raw() allows calling from inside kernel lambdas with pre-swapped column-major parameters, matching the rocBLAS pattern.

Warm prompt (Qwen3-Coder-Next 4-bit, Strix Halo): 80 tok/s → 207 tok/s (2.6x faster). First-call overhead from the algorithm search is amortized by the application warmup pass.
- hipblaslt_gemm_raw() for calling from inside kernel lambdas with pre-swapped col-major params. Used in the QMM bf16 dequant+GEMM path.
- Warm prompt: 80→207 tok/s with the hipBLASLt algorithm cache primed.
- CommandEncoder graph capture API (begin_capture, end_capture, replay, reset_graph) using hipStreamBeginCapture/EndCapture/GraphLaunch. Infrastructure for future decode acceleration (18→34 tok/s potential). Not yet active due to MLX lazy-eval incompatibility with capture mode.
Replace the 5-operation copy chain (2 allocs + 2 hipMemcpyAsync + 1 kernel) with single-dispatch strided copy kernels for non-contiguous arrays.

New kernels:
- strided_row_copy_kernel: inner-contiguous with outer stride gap (common pattern from take/gather_sort). Uses 4-byte word copies when aligned.
- strided_general_copy_kernel: arbitrary strides; shapes/strides passed as by-value structs (zero device allocation).

Tiered dispatch in ensure_row_contiguous_matrix:
1. Already contiguous → return (fast path, unchanged)
2. Inner-contiguous outer gap → strided_row_copy_kernel (1 dispatch)
3. General non-contiguous → strided_general_copy_kernel (1 dispatch)
4. ndim > 10 → old contiguous_copy_gpu fallback

Net: each non-contiguous copy drops from 5 GPU operations to 1.
Coarser size buckets for large allocations improve the buffer cache hit rate during LLM decode. Without this, slightly different allocation sizes (e.g., 1.01MB vs 1.02MB) miss the cache and trigger hipExtMallocWithFlags at ~7ms each.

Previous: page-aligned (16KB granularity) for all sizes >= 16KB
New: page-aligned for 16KB-1MB, power-of-2 for >= 1MB

Trades up to 2x memory waste for large buffers in exchange for dramatically fewer cache misses during steady-state decode.
The power-of-2 rounding for >= 1MB allocations caused OOM by doubling large allocations that exceeded the 2GB device-local VRAM on iGPU. Reverted to page-aligned (16KB) rounding for all large sizes. hipExtMallocWithFlags remains the primary path for iGPU (best GPU bandwidth via fine-grained coherent access). Falls back to hipMallocManaged for allocations that exceed VRAM capacity, accessing the full system RAM (126GB on Strix Halo).
I have added you as a collaborator on my fork, so you should be able to push changes directly to this branch (and therefore directly to this PR). Again, amazing work🚀 |
ROCm: WMMA prefill, hipBLASLt, 4-bit MoE dispatch, QMV tuning, iGPU allocator
|
We are still about 2x slower on almost everything (tps and pp), but it's workable and actually usable right now. I am going to look further into how the Mac eval and allocator work to see if I can get some ideas for optimizing allocation; if you benchmark and profile it, the allocations are off the charts. And thanks for the invite! |
|
@goniz do you want to test this again? We should be at a point where it's working. |
|
@NripeshN if you are interested this is a very interesting read. |
Experiment with ROCm backend.
Install MLX with the ROCm backend using:

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON" pip install -e .
closes #2556
Inspired by @zcbenz